{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Device-local array layout control\n",
    "\n",
    "The `jax.experimental.layout` package provides ways to control\n",
    "how JAX arrays are laid out in device-local memory.\n",
    "\n",
    "## Terminology\n",
    "\n",
    "Array layout is tightly coupled with array\n",
    "[sharding](https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html>).\n",
    "Together, a layout and a sharding fully describes how an array's\n",
    "values are laid out across (distributed) memories. Along these lines,\n",
    "we use the following terminology:\n",
    "\n",
    "* **Layout**: how an array's values are laid out within each memory in\n",
    "    which they reside (e.g., in the memory of a single device\n",
    "    memory). A typical layout specification is a minor-to-major order\n",
    "    listing of array dimensions.\n",
    "* **Sharding**: how an array's values are distributed *across*\n",
    "    different memory spaces, such as multiple device memories\n",
    "    (e.g. described by sharding some dimensions and replicating\n",
    "    others).\n",
    "* **Format**: the pairing of **layout** and **sharding**,\n",
    "    providing a complete picture of an array's memory placement.\n",
    "\n",
    "## Types\n",
    "\n",
    "There are two Python types that come up when controlling array\n",
    "layouts: `Layout` and `Format`.\n",
    "\n",
    "* The `Layout` class is used to define the in-memory\n",
    "  layout of an array. It has the following key attributes:\n",
    "\n",
    "  * `major_to_minor`: A tuple of integers specifying the dimension\n",
    "    ordering in memory. For example, for a 2-dimensional array, `(0, 1)`\n",
    "    indicates row-major layout and `(1, 0)` indicates column-major.\n",
    "\n",
    "  * `_tiling`: An intentionally hidden, highly experimental, optional\n",
    "    attribute to specify a tiled layout.\n",
    "\n",
    "  * `AUTO`: A special, static sentinel object that can be used with\n",
    "    `jax.jit` to request that the compiler automatically determine\n",
    "    a good layout for a compiled function's input or output arrays.\n",
    "\n",
    "* The `Format` class carries both a `Layout` and a `Sharding`, with\n",
    "  either one taking on a default value when it is not specified.\n",
    "  When the layout is explicitly specified, the sharding must be\n",
    "  as well.\n",
    "\n",
    "JAX API functions, such as `jax.jit` and `jax.device_put`, accept\n",
    "`Sharding`s for sharding control or `Format`s for additional layout\n",
    "control. They typically do not accept `Layout` instances directly.\n",
    "\n",
    "## Specifying and reading layouts\n",
    "\n",
    "By passing `Format` objects to `jax.jit` in place of shardings (in the\n",
    "`in_shardings` and `out_shardings` arguments), you can guide the\n",
    "compiler's layout decisions. Similarly you can pass `Format`s instead\n",
    "of `Sharding`s to `jax.device_put` to control the layout of the\n",
    "resulting array.\n",
    "\n",
    "Let's see an example that uses both explicit and automatic layouts (as\n",
    "in `Layout.AUTO`). Imagine we have two compiled functions, `init_fn`\n",
    "and `apply_fn`. Say we expect `init_fn` to be called roughly once, but\n",
    "`apply_fn` to be called on the output of `init_fn` many times, so that\n",
    "we care much more about the performance of `apply_fn`. We may want to\n",
    "have the compiler choose a good layout for `apply_fn` and constrain\n",
    "`init_fn` to produce arrays of such layout. We can do this as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax, jax.numpy as jnp\n",
    "from jax.experimental.layout import Layout, Format\n",
    "from jax.sharding import SingleDeviceSharding\n",
    "import numpy as np\n",
    "\n",
    "def init_fn(x, y):\n",
    "  return x * 2, y * 3\n",
    "\n",
    "def apply_fn(x, y):\n",
    "  return x[0, :], y[:, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since `apply_fn` reads a contiguous column of its second argument `y`,\n",
    "it makes sense to lay it out in column-major order (where columns are\n",
    "stored contiguously). Using `Layout.AUTO`, we can ask the compiler to\n",
    "infer good input layouts and see that it indeed chooses to request the\n",
    "second argument in column-major layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "shape = (4 * 128, 8 * 128)\n",
    "duck = jax.ShapeDtypeStruct(shape, jnp.float32)\n",
    "\n",
    "# Compile the `apply` function with layouts inferred automatically\n",
    "apply_exe = jax.jit(\n",
    "    apply_fn,\n",
    "    in_shardings=Format(Layout.AUTO),\n",
    "    out_shardings=Format(Layout.AUTO),\n",
    ").trace(duck, duck).lower().compile()\n",
    "\n",
    "# Read back the inferred input layout\n",
    "arg_formats, kwarg_formats = apply_exe.input_formats\n",
    "assert len(kwarg_formats) == 0\n",
    "assert arg_formats[0].layout.major_to_minor == (0, 1)\n",
    "assert arg_formats[1].layout.major_to_minor == (1, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then compile `init_fn` to explicitly match this layout in its\n",
    "outputs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(\n",
    "    duck, duck).lower().compile()\n",
    "\n",
    "assert init_exe.output_formats == arg_formats"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally we can see how the compiled `apply_fn` behaves when called\n",
    "with differently laid out input arrays. The behavior varies with\n",
    "whether inputs are\n",
    "[committed](https://docs.jax.dev/en/latest/faq.html#controlling-data-and-computation-placement-on-devices). As\n",
    "the following test demonstrates, if the argument arrays are committed,\n",
    "then the pre-compiled `apply_fn` requires they match the layout\n",
    "determined by the compiler above. Meanwhile it accepts uncommitted\n",
    "arrays of any layout (including, of course, the inferred layout). In\n",
    "this case, the arrays may be relaid out prior to invoking the compiled\n",
    "computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-- uncommitted with mismatched layout:\n",
      "x major_to_minor = (0, 1)\n",
      "y major_to_minor = (0, 1)\n",
      "-> `apply` called successfully\n",
      "\n",
      "-- uncommitted with matching layout:\n",
      "x major_to_minor = (0, 1)\n",
      "y major_to_minor = (1, 0)\n",
      "-> `apply` called successfully\n",
      "\n",
      "-- committed with matching layout:\n",
      "x major_to_minor = (0, 1)\n",
      "y major_to_minor = (1, 0)\n",
      "-> `apply` called successfully\n",
      "\n",
      "-- committed with mismatched layout:\n",
      "x major_to_minor = (0, 1)\n",
      "y major_to_minor = (0, 1)\n",
      "-> error: mismatched input layouts\n",
      "\n"
     ]
    }
   ],
   "source": [
    "def test(x, y, msg):\n",
    "  print(f'-- {msg}:')\n",
    "  print('x major_to_minor =', x.format.layout.major_to_minor)\n",
    "  print('y major_to_minor =', y.format.layout.major_to_minor)\n",
    "  try:\n",
    "    apply_exe(x, y)\n",
    "    print('-> `apply` called successfully')\n",
    "  except ValueError as e:\n",
    "    assert 'does not match' in str(e)\n",
    "    print('-> error: mismatched input layouts')\n",
    "  print()\n",
    "\n",
    "dev = jax.devices()[0]\n",
    "\n",
    "x1 = y1 = jnp.ones(shape)\n",
    "test(x1, y1, 'uncommitted with mismatched layout')\n",
    "\n",
    "x2, y2 = init_exe(x1, y1)\n",
    "test(x2, y2, 'uncommitted with matching layout')\n",
    "\n",
    "x3 = jnp.ones(shape)\n",
    "y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),\n",
    "                                           SingleDeviceSharding(dev)))\n",
    "test(x3, y3, 'committed with matching layout')\n",
    "\n",
    "x4 = jnp.ones(shape)\n",
    "y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),\n",
    "                                           SingleDeviceSharding(dev)))\n",
    "test(x4, y4, 'committed with mismatched layout')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Constraining intermediate layouts\n",
    "\n",
    "We can also enforce a specific layout on an intermediate value within\n",
    "a JIT-compiled function using `with_layout_constraint`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jax.experimental.layout import with_layout_constraint\n",
    "\n",
    "@jax.jit\n",
    "def f(x):\n",
    "  y = x.T\n",
    "  # Enforce a specific layout on `y`\n",
    "  y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))\n",
    "  return y * 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is analogous to\n",
    "[`jax.lax.with_sharding_constraint`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.with_sharding_constraint.html),\n",
    "for constraining layouts rather than shardings."
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "formats": "ipynb,md:myst"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
